#!/usr/bin/env python
# dpp_retriever.py

import os
import json
import logging
import faiss
import hydra
import hydra.utils as hu
import numpy as np
import torch
import tqdm
from functools import partial
from transformers import set_seed
from torch.utils.data import DataLoader
from src.utils.misc import parallel_run
from src.utils.collators import DataCollatorWithPaddingAndCuda
from src.models.biencoder import BiEncoder


# Default TOKENIZERS_PARALLELISM to "false" unless the caller already set it.
# NOTE(review): presumably this keeps the tokenizer's own thread pool out of
# the way of the worker processes started later via parallel_run — confirm.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def init_index(idx):
    """
    Publish `idx` as the module-level `index_global`.

    `DPPRetriever.find` passes this function as the `initializer` of
    `parallel_run`; `get_kernel` and `dpp` then read `index_global`.
    """
    globals()["index_global"] = idx

def fast_map_dpp(L: np.ndarray, k: int) -> list[int]:
    """
    Greedy MAP inference for a k-DPP defined by kernel L.

    Items are picked one at a time by largest remaining marginal gain `d`
    (initialised to the diagonal of L). After each pick, one row of the
    incremental factor `C` is computed and `d` is reduced by its square.

    Args:
        L: square (n, n) kernel matrix.
        k: number of items to select.

    Returns:
        Distinct indices into L, in selection order. The list has
        `min(k, n)` entries, so asking for more items than exist returns
        every index once instead of repeating indices.
    """
    n = L.shape[0]
    k = min(k, n)
    C = np.zeros((k, n), dtype=L.dtype)
    d = np.copy(np.diag(L))
    taken = np.zeros(n, dtype=bool)
    selected: list[int] = []

    for i in range(k):
        # Mask already-chosen items: for a rank-deficient kernel every
        # remaining gain can drop to 0, and a plain argmax would then return
        # an index that was selected before.
        j = int(np.argmax(np.where(taken, -np.inf, d)))
        selected.append(j)
        taken[j] = True
        if i == k - 1:
            break

        if not d[j] > 0:
            # No positive gain left (also catches NaN). Dividing by
            # sqrt(d[j]) would poison C and d, so skip the update; the
            # remaining slots are filled by the masked argmax above.
            continue

        sqrt_dj = np.sqrt(d[j])
        # Whole row at once; for i == 0 the correction term is an empty
        # product and evaluates to zeros.
        C[i] = (L[j] - C[:i, j] @ C[:i]) / sqrt_dj
        d -= C[i] ** 2

    return selected

def k_dpp_sampling(kernel_matrix: np.ndarray,
                   rel_scores: np.ndarray,
                   num_ice: int,
                   num_candidates: int,
                   pre_results: list[list[int]] = None,
                   noise_scale: float = 1e-6) -> list[list[int]]:
    """
    Generate `num_candidates` subsets of size `num_ice` by running greedy MAP
    on randomly perturbed copies of the kernel.

    Args:
        kernel_matrix: square (n, n) DPP kernel.
        rel_scores: accepted for interface compatibility; not used here.
        num_ice: size of every subset.
        num_candidates: total number of subsets to return.
        pre_results: subsets to start from (copied, never mutated).
        noise_scale: standard deviation of the symmetric Gaussian noise
            added to the kernel before each greedy run.

    Returns:
        A list that starts with `pre_results` and is extended to
        `num_candidates` entries when it is shorter. Each added subset is a
        sorted list of indices into `kernel_matrix`. Subsets are compared as
        sets, so the same items in a different order are not added twice.
    """
    n = kernel_matrix.shape[0]
    results = list(pre_results) if pre_results is not None else []
    # Order-insensitive keys of everything collected so far.
    seen = {tuple(sorted(s)) for s in results}
    attempts, max_attempts = 0, num_candidates * 5

    while len(results) < num_candidates and attempts < max_attempts:
        noise = np.random.normal(scale=noise_scale, size=kernel_matrix.shape)
        noise = (noise + noise.T) * 0.5  # keep the perturbed kernel symmetric
        L_noisy = kernel_matrix + noise

        # Sort so the stored form matches the sorted subsets used elsewhere
        # (the caller's seed subset and the random fallback below).
        subset = sorted(fast_map_dpp(L_noisy, num_ice))
        key = tuple(subset)
        if key not in seen:
            seen.add(key)
            results.append(subset)
        attempts += 1

    # Fallback to pure random if needed. Duplicates are rejected for a bounded
    # number of draws; after that they are accepted so the loop always ends,
    # even when fewer than `num_candidates` distinct subsets exist.
    draws = 0
    while len(results) < num_candidates:
        rnd = sorted(np.random.choice(n, num_ice, replace=False).tolist())
        key = tuple(rnd)
        draws += 1
        if key in seen and draws <= max_attempts:
            continue
        seen.add(key)
        results.append(rnd)

    return results

def get_kernel(query_emb: np.ndarray,
               candidates: list[int],
               scale_factor: float):
    """
    Build the query-conditioned DPP kernel over `candidates`.

    With cosine similarities rescaled from [-1, 1] to [0, 1]:
      s_i  = (cos(query, cand_i) + 1) / 2
      r_i  = exp((s_i - max_k s_k) / (2 * scale_factor))
      S_ij = (cos(cand_i, cand_j) + 1) / 2
      L_ij = r_i * S_ij * r_j

    Candidate vectors are read from the module-level `index_global` via
    `index_global.index.reconstruct(c)` for every `c` in `candidates`.
    NOTE(review): this presumes each value in `candidates` is a valid
    position in that wrapped index — confirm against how the index is built.

    Returns:
        (L, r): the kernel matrix and the per-candidate relevance vector.
    """
    wrapped_index = index_global.index
    cand_vecs = np.stack(
        [wrapped_index.reconstruct(c) for c in candidates],
        axis=0,
    )
    # Normalise rows in place so dot products below are cosine similarities.
    row_norms = np.linalg.norm(cand_vecs, axis=1, keepdims=True)
    np.divide(cand_vecs, row_norms, out=cand_vecs)
    unit_query = query_emb / np.linalg.norm(query_emb)

    query_sim = (unit_query @ cand_vecs.T + 1) / 2
    # Shifting by the maximum keeps the exponent <= 0, so the largest
    # relevance is exactly 1.
    relevance = np.exp((query_sim - query_sim.max()) / (2 * scale_factor))

    pairwise_sim = (cand_vecs @ cand_vecs.T + 1) / 2
    kernel = relevance[None, :] * pairwise_sim * relevance[:, None]
    return kernel, relevance

def dpp(entry: dict,
        num_candidates: int,
        num_ice: int,
        dpp_topk: int,
        scale_factor: float):
    """
    Perform DPP based retrieval for a single query entry.

    Args:
        entry: dict with key 'embed' (the query embedding) and key 'entry'
            (the record that receives the results).
        num_candidates: number of candidate subsets to produce.
        num_ice: number of items per subset.
        dpp_topk: number of nearest neighbours fetched from `index_global`.
        scale_factor: forwarded to `get_kernel`.

    Returns:
        `entry['entry']`, modified in place: 'ctxs' holds the ids of the
        greedy MAP subset and 'ctxs_candidates' the ids of every subset.
        Subsets are smaller than `num_ice` when the search returns fewer
        usable neighbours than that.
    """
    # 1) get top-dpp_topk nearest ids
    q_emb = np.expand_dims(entry['embed'], axis=0)
    _, ids = index_global.search(q_emb.astype('float32'), dpp_topk)
    # FAISS fills missing neighbours with label -1 when the index holds fewer
    # than dpp_topk vectors; those are not real ids and must not be
    # reconstructed or returned.
    cands = [c for c in ids[0].tolist() if c != -1]
    # Cannot select more items than there are usable candidates.
    subset_size = min(num_ice, len(cands))

    # 2) build conditional DPP kernel + relevance scores
    L, rel = get_kernel(entry['embed'], cands, scale_factor)

    # 3) greedy MAP selection (sorted positions within `cands`)
    base_subset = sorted(fast_map_dpp(L, subset_size))
    all_sets = [base_subset]

    # 4) if multiple candidates requested, run k-DPP sampling
    if num_candidates > 1:
        all_sets = k_dpp_sampling(L, rel, subset_size, num_candidates,
                                  pre_results=all_sets)

    # 5) attach to original entry, mapping positions back to index ids
    out = entry['entry']
    out['ctxs'] = [cands[i] for i in all_sets[0]]
    out['ctxs_candidates'] = [[cands[i] for i in s] for s in all_sets]
    return out

# ------------------------------------------------------------------------------
# DPPRetriever
# ------------------------------------------------------------------------------
class DPPRetriever:
    """
    Encode an index split and a query split with a BiEncoder, select
    examples for every query with `dpp`, and write the results as JSON.
    """

    def __init__(self, cfg):
        """
        Instantiate the query reader, the model and the FAISS index from `cfg`.

        Reads `cfg.dataset_reader`, `cfg.index_reader`, `cfg.batch_size`,
        `cfg.model_config`, `cfg.pretrained_model_path`, `cfg.output_file`,
        `cfg.num_candidates`, `cfg.num_ice` and `cfg.dpp_topk`.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Query dataset reader
        self.reader = hu.instantiate(cfg.dataset_reader)
        self.dl = self._make_loader(self.reader, cfg.batch_size)

        # Model setup
        mc = hu.instantiate(cfg.model_config)
        if cfg.pretrained_model_path:
            logger.info(f"Loading model from {cfg.pretrained_model_path}")
            self.model = BiEncoder.from_pretrained(cfg.pretrained_model_path,
                                                   config=mc)
        else:
            self.model = BiEncoder(mc)
        self.model.to(self.device).eval()

        # Hyperparameters
        self.output_file    = cfg.output_file
        self.num_candidates = cfg.num_candidates
        self.num_ice        = cfg.num_ice
        self.dpp_topk       = cfg.dpp_topk
        # Taken from the model when it exposes one, otherwise 0.1.
        self.scale_factor   = getattr(self.model, "scale_factor", 0.1)

        # Build FAISS index over the "index" split
        self.index = self._build_index(cfg)

    def _make_loader(self, reader, batch_size):
        """Return a DataLoader over `reader` whose collator uses the reader's tokenizer and `self.device`."""
        collator = DataCollatorWithPaddingAndCuda(
            tokenizer=reader.tokenizer,
            device=self.device
        )
        return DataLoader(reader,
                          batch_size=batch_size,
                          collate_fn=collator)

    def _iter_embeddings(self, loader, encode_ctx):
        """Yield `(embeddings as a NumPy array, metadata list)` for every batch of `loader`."""
        for batch in tqdm.tqdm(loader):
            with torch.no_grad():
                # "metadata" is not a model input; remove it before encoding.
                meta = batch.pop("metadata")
                embs = self.model.encode(**batch, encode_ctx=encode_ctx)
            yield embs.cpu().numpy(), meta.data

    def _build_index(self, cfg):
        """
        Encode `cfg.index_reader` with `encode_ctx=True` and return an
        inner-product FAISS index keyed by each item's metadata "id".
        """
        logger.info("Building FAISS index…")
        idx_reader = hu.instantiate(cfg.index_reader)
        dl = self._make_loader(idx_reader, cfg.batch_size)

        ids_all, embs_all = [], []
        for arr, meta in self._iter_embeddings(dl, encode_ctx=True):
            embs_all.append(arr.astype("float32"))
            ids_all.extend(m["id"] for m in meta)

        embs_all = np.vstack(embs_all)
        # NOTE(review): get_kernel reads vectors back with
        # index_global.index.reconstruct(id), i.e. from the wrapped flat index;
        # that presumably requires ids to equal insertion positions — confirm.
        index = faiss.IndexIDMap(faiss.IndexFlatIP(embs_all.shape[1]))
        # FAISS ids are 64-bit; do not rely on the platform's default int width.
        index.add_with_ids(embs_all, np.asarray(ids_all, dtype="int64"))
        return index

    def _encode_queries(self):
        """Encode `self.dl` with `encode_ctx=False`; return one `{"embed", "metadata"}` dict per query."""
        results = []
        for arr, meta in self._iter_embeddings(self.dl, encode_ctx=False):
            results.extend({"embed": e, "metadata": m} for e, m in zip(arr, meta))
        return results

    def find(self):
        """Run DPP retrieval for every query and dump the results to `self.output_file` as JSON."""
        # 1) encode all queries
        queries = self._encode_queries()

        # 2) attach full original entries so dpp() can return label, etc.
        for q in queries:
            q['entry'] = self.reader.dataset_wrapper[q['metadata']['id']]

        # 3) run DPP in parallel; init_index publishes the index as the
        #    module-level global that dpp() and get_kernel() read
        runner = partial(
            dpp,
            num_candidates=self.num_candidates,
            num_ice=self.num_ice,
            dpp_topk=self.dpp_topk,
            scale_factor=self.scale_factor
        )
        data = parallel_run(
            func=runner,
            args_list=queries,
            initializer=init_index,
            initargs=(self.index,)
        )

        # 4) save results
        with open(self.output_file, "w") as f:
            json.dump(data, f)

@hydra.main(config_path="configs", config_name="dense_retriever")
def main(cfg):
    """Hydra entry point: call set_seed(42), build a DPPRetriever from `cfg`, and run `find()`."""
    set_seed(42)
    DPPRetriever(cfg).find()

# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
